# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

#' @include arrow-object.R

# Base class for Array, ChunkedArray, and Scalar, for S3 method dispatch only.
# Does not exist in C++ class hierarchy
ArrowDatum <- R6Class("ArrowDatum",
  inherit = ArrowObject,
  public = list(
    cast = function(target_type, safe = TRUE, ...) {
      opts <- cast_options(safe, ...)
      opts$to_type <- as_type(target_type)
      call_function("cast", self, options = opts)
    },
    SortIndices = function(descending = FALSE) {
      assert_that(is.logical(descending))
      assert_that(length(descending) == 1L)
      assert_that(!is.na(descending))
      call_function(
        "sort_indices",
        self,
        options = list(names = "", orders = as.integer(descending))
      )
    }
  )
)

#' @export
length.ArrowDatum <- function(x) x$length()

#' @export
is.finite.ArrowDatum <- function(x) {
  is_fin <- call_function("is_finite", x)
  # for compatibility with base::is.finite(), return FALSE for NA_real_
  is_fin & !is.na(is_fin)
}

#' @export
is.infinite.ArrowDatum <- function(x) {
  is_inf <- call_function("is_inf", x)
  # for compatibility with base::is.infinite(), return FALSE for NA_real_
  is_inf & !is.na(is_inf)
}

#' @export
is.na.ArrowDatum <- function(x) {
  call_function("is_null", x, options = list(nan_is_null = TRUE))
}

#' @export
is.nan.ArrowDatum <- function(x) {
  if (x$type_id() %in% TYPES_WITH_NAN) {
    # TODO(ARROW-13366): if an option is added to the is_nan kernel to treat NA
    # as NaN, use that to simplify the code here
    call_function("is_nan", x) & call_function("is_valid", x)
  } else {
    Scalar$create(FALSE)$as_array(length(x))
  }
}

#' @export
as.vector.ArrowDatum <- function(x, mode) {
  x$as_vector()
}

#' @export
Ops.ArrowDatum <- function(e1, e2) {
  if (missing(e2)) {
    switch(.Generic,
      "!" = return(eval_array_expression(.Generic, e1)),
      "+" = return(eval_array_expression(.Generic, 0L, e1)),
      "-" = return(eval_array_expression("negate_checked", e1)),
    )
  }

  switch(.Generic,
    "+" = ,
    "-" = ,
    "*" = ,
    "/" = ,
    "^" = ,
    "%%" = ,
    "%/%" = ,
    "==" = ,
    "!=" = ,
    "<" = ,
    "<=" = ,
    ">=" = ,
    ">" = ,
    "&" = ,
    "|" = {
      eval_array_expression(.Generic, e1, e2)
    },
    stop(paste0("Unsupported operation on `", class(e1)[1L], "` : "), .Generic, call. = FALSE)
  )
}

#' @export
Math.ArrowDatum <- function(x, ..., base = exp(1), digits = 0) {
  switch(.Generic,
    abs = eval_array_expression("abs_checked", x),
    ceiling = eval_array_expression("ceil", x),
    sign = eval_array_expression("sign", x),
    floor = eval_array_expression("floor", x),
    trunc = eval_array_expression("trunc", x),
    acos = eval_array_expression("acos_checked", x),
    asin = eval_array_expression("asin_checked", x),
    atan = eval_array_expression("atan", x),
    cos = eval_array_expression("cos_checked", x),
    sin = eval_array_expression("sin_checked", x),
    tan = eval_array_expression("tan_checked", x),
    log = eval_array_expression("logb_checked", x, base),
    log10 = eval_array_expression("log10_checked", x),
    log2 = eval_array_expression("log2_checked", x),
    log1p = eval_array_expression("log1p_checked", x),
    round = eval_array_expression(
      "round",
      x,
      options = list(ndigits = digits, round_mode = RoundMode$HALF_TO_EVEN)
    ),
    sqrt = eval_array_expression("sqrt_checked", x),
    exp = eval_array_expression("power_checked", exp(1), x),
    cumsum = eval_array_expression("cumulative_sum_checked", x),
    cumprod = eval_array_expression("cumulative_prod_checked", x),
    cummax = eval_array_expression("cumulative_max", x),
    cummin = eval_array_expression("cumulative_min", x),
    signif = ,
    expm1 = ,
    cospi = ,
    sinpi = ,
    tanpi = ,
    cosh = ,
    sinh = ,
    tanh = ,
    acosh = ,
    asinh = ,
    atanh = ,
    lgamma = ,
    gamma = ,
    digamma = ,
    trigamma = ,
    stop(paste0("Unsupported operation on `", class(x)[1L], "` : "), .Generic, call. = FALSE)
  )
}

# Wrapper around call_function that:
# (1) maps R function names to Arrow C++ compute ("/" --> "divide_checked")
# (2) wraps R input args as Array or Scalar
eval_array_expression <- function(FUN,
                                  ...,
                                  args = list(...),
                                  options = empty_named_list()) {
  if (FUN == "-" && length(args) == 1L) {
    if (inherits(args[[1]], "ArrowObject")) {
      return(eval_array_expression("negate_checked", args[[1]]))
    } else {
      return(-args[[1]])
    }
  }
  args <- lapply(args, .wrap_arrow)

  # In Arrow, "divide" is one function, which does integer division on
  # integer inputs and floating-point division on floats
  if (FUN == "/") {
    # TODO: omg so many ways it's wrong to assume these types
    args <- map(args, ~ .$cast(float64()))
  } else if (FUN == "%/%") {
    # In R, integer division works like floor(float division)
    out <- eval_array_expression("/", args = args)

    # integer output only for all integer input
    int_type_ids <- Type[toupper(INTEGER_TYPES)]
    numerator_is_int <- args[[1]]$type_id() %in% int_type_ids
    denominator_is_int <- args[[2]]$type_id() %in% int_type_ids

    if (numerator_is_int && denominator_is_int) {
      out_float <- eval_array_expression(
        "if_else",
        eval_array_expression("equal", args[[2]], 0L),
        Scalar$create(NA_integer_),
        eval_array_expression("floor", out)
      )
      return(out_float$cast(args[[1]]$type))
    } else {
      return(eval_array_expression("floor", out))
    }
  } else if (FUN == "%%") {
    # We can't simply do {e1 - e2 * ( e1 %/% e2 )} since Ops.Array evaluates
    # eagerly, but we can build that up
    quotient <- eval_array_expression("%/%", args = args)
    base <- eval_array_expression("*", quotient, args[[2]])
    # this cast is to ensure that the result of this and e1 are the same
    # (autocasting only applies to scalars)
    base <- base$cast(args[[1]]$type)
    return(eval_array_expression("-", args[[1]], base))
  }

  call_function(
    .array_function_map[[FUN]] %||% FUN,
    args = args,
    options = options
  )
}

.wrap_arrow <- function(arg) {
  if (!inherits(arg, "ArrowObject")) {
    arg <- Scalar$create(arg)
  }
  arg
}

#' @export
na.omit.ArrowDatum <- function(object, ...) {
  object$Filter(!is.na(object))
}

#' @export
na.exclude.ArrowDatum <- na.omit.ArrowDatum

#' @export
na.fail.ArrowDatum <- function(object, ...) {
  if (object$null_count > 0) {
    stop("missing values in object", call. = FALSE)
  }
  object
}

filter_rows <- function(x, i, keep_na = TRUE, ...) {
  # General purpose function for [ row subsetting with R semantics
  # Based on the input for `i`, calls x$Filter, x$Slice, or x$Take
  nrows <- x$num_rows %||% x$length() # Depends on whether Array or Table-like
  if (is.logical(i)) {
    if (isTRUE(i)) {
      # Shortcut without doing any work
      x
    } else {
      i <- rep_len(i, nrows) # For R recycling behavior; consider vctrs::vec_recycle()
      x$Filter(i, keep_na)
    }
  } else if (is.numeric(i)) {
    if (all(i < 0)) {
      # in R, negative i means "everything but i"
      i <- setdiff(seq_len(nrows), -1 * i)
    }
    if (is.sliceable(i)) {
      x$Slice(i[1] - 1, length(i))
    } else if (all(i > 0)) {
      x$Take(i - 1)
    } else {
      stop("Cannot mix positive and negative indices", call. = FALSE)
    }
  } else if (is.Array(i, INTEGER_TYPES)) {
    # NOTE: this doesn't do the - 1 offset
    x$Take(i)
  } else if (is.Array(i, "bool")) {
    x$Filter(i, keep_na)
  } else {
    # Unsupported cases
    if (is.Array(i)) {
      stop("Cannot extract rows with an Array of type ", i$type$ToString(), call. = FALSE)
    }
    stop("Cannot extract rows with an object of class ", class(i), call. = FALSE)
  }
}

#' @export
`[.ArrowDatum` <- filter_rows

#' @importFrom utils head
#' @export
head.ArrowDatum <- function(x, n = 6L, ...) {
  assert_is(n, c("numeric", "integer"))
  assert_that(length(n) == 1)
  len <- NROW(x)
  if (n < 0) {
    # head(x, negative) means all but the last n rows
    n <- max(len + n, 0)
  } else {
    n <- min(len, n)
  }
  if (!is.integer(n)) {
    n <- floor(n)
  }
  if (n == len) {
    return(x)
  }
  x$Slice(0, n)
}

#' @importFrom utils tail
#' @export
tail.ArrowDatum <- function(x, n = 6L, ...) {
  assert_is(n, c("numeric", "integer"))
  assert_that(length(n) == 1)
  if (!is.integer(n)) {
    n <- floor(n)
  }
  len <- NROW(x)
  if (n < 0) {
    # tail(x, negative) means all but the first n rows
    n <- min(-n, len)
  } else {
    n <- max(len - n, 0)
  }
  if (n == 0) {
    return(x)
  }
  x$Slice(n)
}

is.sliceable <- function(i) {
  # Determine whether `i` can be expressed as a $Slice() command
  is.numeric(i) &&
    length(i) > 0 &&
    all(i > 0) &&
    i[1] <= i[length(i)] &&
    identical(as.integer(i), i[1]:i[length(i)])
}

#' @export
as.double.ArrowDatum <- function(x, ...) as.double(as.vector(x), ...)

#' @export
as.integer.ArrowDatum <- function(x, ...) as.integer(as.vector(x), ...)

#' @export
as.character.ArrowDatum <- function(x, ...) as.character(as.vector(x), ...)

#' @export
sort.ArrowDatum <- function(x, decreasing = FALSE, na.last = NA, ...) {
  # Arrow always sorts nulls at the end of the array. This corresponds to
  # sort(na.last = TRUE). For the other two cases (na.last = NA and
  # na.last = FALSE) we need to use workarounds.
  # TODO(ARROW-14085): use NullPlacement ArraySortOptions instead of this workaround
  if (is.na(na.last)) {
    # Filter out NAs before sorting
    x <- x$Filter(!is.na(x))
    x$Take(x$SortIndices(descending = decreasing))
  } else if (na.last) {
    x$Take(x$SortIndices(descending = decreasing))
  } else {
    # Create a new array that encodes missing values as 1 and non-missing values
    # as 0. Sort descending by that array first to get the NAs at the beginning
    tbl <- Table$create(x = x, `is_na` = as.integer(is.na(x)))
    tbl$x$Take(tbl$SortIndices(names = c("is_na", "x"), descending = c(TRUE, decreasing)))
  }
}
